{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Cayt5nCXb3WG"
      },
      "source": [
        "##### Copyright 2022 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DYL3CXHRb9-f"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VYDmsvURYZjz"
      },
      "source": [
        "# Object detection with Model Garden\n",
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/tfmodels/vision/object_detection\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/docs/vision/object_detection.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/docs/vision/object_detection.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View on GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/models/docs/vision/object_detection.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "69aQq_PXcUvL"
      },
      "source": [
        "This tutorial fine-tunes a [RetinaNet](https://arxiv.org/abs/1708.02002) with a ResNet-50 backbone from the [TensorFlow Model Garden](https://pypi.org/project/tf-models-official/) package (tensorflow-models) to detect three types of blood cells in the [BCCD](https://public.roboflow.com/object-detection/bccd) dataset. The RetinaNet is pretrained on [COCO](https://cocodataset.org/) train2017 and evaluated on [COCO](https://cocodataset.org/) val2017.\n",
        "\n",
        "[Model Garden](https://www.tensorflow.org/tfmodels) contains a collection of state-of-the-art models, implemented with TensorFlow's high-level APIs. The implementations demonstrate the best practices for modeling, letting users take full advantage of TensorFlow for their research and product development.\n",
        "\n",
        "This tutorial demonstrates how to:\n",
        "\n",
        "1. Use models from the TensorFlow Model Garden (TFM) package.\n",
        "2. Fine-tune a pre-trained RetinaNet with ResNet-50 as backbone for object detection.\n",
        "3. Export the tuned RetinaNet model."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IeSHlZyUZl6f"
      },
      "source": [
        "## Install necessary dependencies"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Pip0LHj3ZqgL"
      },
      "outputs": [],
      "source": [
        "!pip install -U -q \"tf-models-official\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "H3kS7Y0sZsIj"
      },
      "source": [
        "## Import required libraries"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hFdVelJ2YbQz"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import io\n",
        "import pprint\n",
        "import tempfile\n",
        "import matplotlib\n",
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "from PIL import Image\n",
        "from io import BytesIO\n",
        "from IPython import display\n",
        "from urllib.request import urlopen"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TF77J-iMZn_u"
      },
      "source": [
        "## Import required libraries from TensorFlow Models"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iT27_SOTY1Dz"
      },
      "outputs": [],
      "source": [
        "import orbit\n",
        "import tensorflow_models as tfm\n",
        "\n",
        "from official.core import exp_factory\n",
        "from official.core import config_definitions as cfg\n",
        "from official.vision.serving import export_saved_model_lib\n",
        "from official.vision.ops.preprocess_ops import normalize_image\n",
        "from official.vision.ops.preprocess_ops import resize_and_crop_image\n",
        "from official.vision.utils.object_detection import visualization_utils\n",
        "from official.vision.dataloaders.tf_example_decoder import TfExampleDecoder\n",
        "\n",
        "pp = pprint.PrettyPrinter(indent=4) # Set Pretty Print Indentation\n",
        "print(tf.__version__) # Check the version of tensorflow used\n",
        "\n",
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WGbMG8cpZyKa"
      },
      "source": [
        "## Custom dataset preparation for object detection\n",
        "\n",
        "Models in the official Model Garden repository require data in the TFRecord format.\n",
        "\n",
        "See [this resource](https://www.tensorflow.org/tutorials/load_data/tfrecord) to learn more about the TFRecord data format.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Uq5hcbJ8Z4th"
      },
      "source": [
        "### Upload your custom data to Drive or to the notebook's local disk, then unzip it"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rDixpoqoY3Za"
      },
      "outputs": [],
      "source": [
        "!curl -L 'https://public.roboflow.com/ds/ZpYLqHeT0W?key=ZXfZLRnhsc' > './BCCD.v1-bccd.coco.zip'\n",
        "!unzip -q -o './BCCD.v1-bccd.coco.zip' -d './BCC.v1-bccd.coco/'\n",
        "!rm './BCCD.v1-bccd.coco.zip'"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GI1h9UChZ8cC"
      },
      "source": [
        "### CLI command to convert the training data."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "x_8cmB82Y65O"
      },
      "outputs": [],
      "source": [
        "%%bash\n",
        "\n",
        "TRAIN_DATA_DIR='./BCC.v1-bccd.coco/train'\n",
        "TRAIN_ANNOTATION_FILE_DIR='./BCC.v1-bccd.coco/train/_annotations.coco.json'\n",
        "OUTPUT_TFRECORD_TRAIN='./bccd_coco_tfrecords/train'\n",
        "\n",
        "# Need to provide\n",
        "  # 1. image_dir: where images are present\n",
        "  # 2. object_annotations_file: where annotations are listed in json format\n",
        "  # 3. output_file_prefix: where to write the converted TFRecord files\n",
        "python -m official.vision.data.create_coco_tf_record --logtostderr \\\n",
        "  --image_dir=${TRAIN_DATA_DIR} \\\n",
        "  --object_annotations_file=${TRAIN_ANNOTATION_FILE_DIR} \\\n",
        "  --output_file_prefix=$OUTPUT_TFRECORD_TRAIN \\\n",
        "  --num_shards=1"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VuwZpwUoaAKU"
      },
      "source": [
        "### CLI command to convert the validation data."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "q8mQ8prGY8kh"
      },
      "outputs": [],
      "source": [
        "%%bash\n",
        "\n",
        "VALID_DATA_DIR='./BCC.v1-bccd.coco/valid'\n",
        "VALID_ANNOTATION_FILE_DIR='./BCC.v1-bccd.coco/valid/_annotations.coco.json'\n",
        "OUTPUT_TFRECORD_VALID='./bccd_coco_tfrecords/valid'\n",
        "\n",
        "python -m official.vision.data.create_coco_tf_record --logtostderr \\\n",
        "  --image_dir=$VALID_DATA_DIR \\\n",
        "  --object_annotations_file=$VALID_ANNOTATION_FILE_DIR \\\n",
        "  --output_file_prefix=$OUTPUT_TFRECORD_VALID \\\n",
        "  --num_shards=1"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BYGxNNAXaCW6"
      },
      "source": [
        "### CLI command to convert the test data."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-K8qlfstY-Ua"
      },
      "outputs": [],
      "source": [
        "%%bash\n",
        "\n",
        "TEST_DATA_DIR='./BCC.v1-bccd.coco/test'\n",
        "TEST_ANNOTATION_FILE_DIR='./BCC.v1-bccd.coco/test/_annotations.coco.json'\n",
        "OUTPUT_TFRECORD_TEST='./bccd_coco_tfrecords/test'\n",
        "\n",
        "python -m official.vision.data.create_coco_tf_record --logtostderr \\\n",
        "  --image_dir=$TEST_DATA_DIR \\\n",
        "  --object_annotations_file=$TEST_ANNOTATION_FILE_DIR \\\n",
        "  --output_file_prefix=$OUTPUT_TFRECORD_TEST \\\n",
        "  --num_shards=1"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cW7hQEJTaEtj"
      },
      "source": [
        "## Configure the RetinaNet ResNet FPN COCO model for a custom dataset.\n",
        "\n",
        "The dataset used to fine-tune the checkpoint is Blood Cells Detection (BCCD)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PMGEl7iXZAAF"
      },
      "outputs": [],
      "source": [
        "train_data_input_path = './bccd_coco_tfrecords/train-00000-of-00001.tfrecord'\n",
        "valid_data_input_path = './bccd_coco_tfrecords/valid-00000-of-00001.tfrecord'\n",
        "test_data_input_path = './bccd_coco_tfrecords/test-00000-of-00001.tfrecord'\n",
        "model_dir = './trained_model/'\n",
        "export_dir ='./exported_model/'"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2DJpKvdeaHF3"
      },
      "source": [
        "In Model Garden, the collections of parameters that define a model are called *configs*. Model Garden can create a config based on a known set of parameters via a [factory](https://en.wikipedia.org/wiki/Factory_method_pattern).\n",
        "\n",
        "\n",
        "Use the `retinanet_resnetfpn_coco` experiment configuration, as defined by `tfm.vision.configs.retinanet.retinanet_resnetfpn_coco`.\n",
        "\n",
        "The configuration defines an experiment to train a RetinaNet with ResNet-50 as the backbone and FPN as the decoder. The default configuration is trained on [COCO](https://cocodataset.org/) train2017 and evaluated on [COCO](https://cocodataset.org/) val2017.\n",
        "\n",
        "Other experiments are also available, such as\n",
        "`retinanet_spinenet_coco`, `fasterrcnn_resnetfpn_coco` and more. To switch to one of them, change the experiment name passed to the `get_exp_config` function.\n",
        "\n",
        "This tutorial fine-tunes from the ResNet-50 backbone checkpoint that is already set in the default configuration."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Ie1ObPH9ZBpa"
      },
      "outputs": [],
      "source": [
        "exp_config = exp_factory.get_exp_config('retinanet_resnetfpn_coco')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LFhjFkw-alba"
      },
      "source": [
        "### Adjust the model and dataset configurations so that they work with the custom dataset (in this case `BCCD`)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ej7j6dvIZDQA"
      },
      "outputs": [],
      "source": [
        "batch_size = 8\n",
        "num_classes = 3\n",
        "\n",
        "HEIGHT, WIDTH = 256, 256\n",
        "IMG_SIZE = [HEIGHT, WIDTH, 3]\n",
        "\n",
        "# Backbone config.\n",
        "exp_config.task.freeze_backbone = False\n",
        "exp_config.task.annotation_file = ''\n",
        "\n",
        "# Model config.\n",
        "exp_config.task.model.input_size = IMG_SIZE\n",
        "exp_config.task.model.num_classes = num_classes + 1\n",
        "exp_config.task.model.detection_generator.tflite_post_processing.max_classes_per_detection = exp_config.task.model.num_classes\n",
        "\n",
        "# Training data config.\n",
        "exp_config.task.train_data.input_path = train_data_input_path\n",
        "exp_config.task.train_data.dtype = 'float32'\n",
        "exp_config.task.train_data.global_batch_size = batch_size\n",
        "exp_config.task.train_data.parser.aug_scale_max = 1.0\n",
        "exp_config.task.train_data.parser.aug_scale_min = 1.0\n",
        "\n",
        "# Validation data config.\n",
        "exp_config.task.validation_data.input_path = valid_data_input_path\n",
        "exp_config.task.validation_data.dtype = 'float32'\n",
        "exp_config.task.validation_data.global_batch_size = batch_size"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ROVc1rayaqI1"
      },
      "source": [
        "### Adjust the trainer configuration."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BZsCVBafZFIE"
      },
      "outputs": [],
      "source": [
        "logical_device_names = [logical_device.name for logical_device in tf.config.list_logical_devices()]\n",
        "\n",
        "if 'GPU' in ''.join(logical_device_names):\n",
        "  print('This may be broken in Colab.')\n",
        "  device = 'GPU'\n",
        "elif 'TPU' in ''.join(logical_device_names):\n",
        "  print('This may be broken in Colab.')\n",
        "  device = 'TPU'\n",
        "else:\n",
        "  print('Running on CPU is slow, so only train for a few steps.')\n",
        "  device = 'CPU'\n",
        "\n",
        "\n",
        "train_steps = 1000\n",
        "exp_config.trainer.steps_per_loop = 100 # steps_per_loop = num_of_training_examples // train_batch_size\n",
        "\n",
        "exp_config.trainer.summary_interval = 100\n",
        "exp_config.trainer.checkpoint_interval = 100\n",
        "exp_config.trainer.validation_interval = 100\n",
        "exp_config.trainer.validation_steps = 100 # validation_steps = num_of_validation_examples // eval_batch_size\n",
        "exp_config.trainer.train_steps = train_steps\n",
        "exp_config.trainer.optimizer_config.warmup.linear.warmup_steps = 100\n",
        "exp_config.trainer.optimizer_config.learning_rate.type = 'cosine'\n",
        "exp_config.trainer.optimizer_config.learning_rate.cosine.decay_steps = train_steps\n",
        "exp_config.trainer.optimizer_config.learning_rate.cosine.initial_learning_rate = 0.1\n",
        "exp_config.trainer.optimizer_config.warmup.linear.warmup_learning_rate = 0.05"
      ]
    },
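    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The trainer settings above combine a linear warmup with a cosine decay. The cell below is a minimal sketch of that shape in NumPy, not Model Garden's implementation: `sketch_learning_rate` is a hypothetical helper, and it assumes the warmup ramps linearly from `warmup_learning_rate` to the value the cosine curve takes at the end of warmup. The values produced by the library's optimizer factory may differ."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "def sketch_learning_rate(step, initial_lr=0.1, decay_steps=1000,\n",
        "                         warmup_steps=100, warmup_lr=0.05):\n",
        "  # Illustrative approximation of cosine decay with linear warmup.\n",
        "  def cosine(s):\n",
        "    # Cosine decay from initial_lr towards 0 over decay_steps.\n",
        "    s = min(s, decay_steps)\n",
        "    return 0.5 * initial_lr * (1.0 + np.cos(np.pi * s / decay_steps))\n",
        "\n",
        "  if step < warmup_steps:\n",
        "    # Assumed warmup: interpolate from warmup_lr to the cosine value at warmup_steps.\n",
        "    target = cosine(warmup_steps)\n",
        "    return warmup_lr + (target - warmup_lr) * step / warmup_steps\n",
        "  return cosine(step)\n",
        "\n",
        "for sketch_step in (0, 50, 100, 500, 1000):\n",
        "  print(sketch_step, round(float(sketch_learning_rate(sketch_step)), 5))"
      ]
    },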
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XS6cJfs2atgI"
      },
      "source": [
        "### Print the modified configuration."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IvfJlqI7ZIcD"
      },
      "outputs": [],
      "source": [
        "pp.pprint(exp_config.as_dict())\n",
        "display.Javascript('google.colab.output.setIframeHeight(\"500px\");')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6o5mbpRBawbs"
      },
      "source": [
        "### Set up the distribution strategy."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2NvY8QHOZKGr"
      },
      "outputs": [],
      "source": [
        "if exp_config.runtime.mixed_precision_dtype == tf.float16:\n",
        "    tf.keras.mixed_precision.set_global_policy('mixed_float16')\n",
        "\n",
        "if 'GPU' in ''.join(logical_device_names):\n",
        "  distribution_strategy = tf.distribute.MirroredStrategy()\n",
        "elif 'TPU' in ''.join(logical_device_names):\n",
        "  tf.tpu.experimental.initialize_tpu_system()\n",
        "  tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='/device:TPU_SYSTEM:0')\n",
        "  distribution_strategy = tf.distribute.experimental.TPUStrategy(tpu)\n",
        "else:\n",
        "  print('Warning: this will be really slow.')\n",
        "  distribution_strategy = tf.distribute.OneDeviceStrategy(logical_device_names[0])\n",
        "\n",
        "print('Done')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4wPtJgoOa33v"
      },
      "source": [
        "## Create the `Task` object (`tfm.core.base_task.Task`) from the `config_definitions.TaskConfig`.\n",
        "\n",
        "The `Task` object has all the methods necessary for building the dataset, building the model, and running training & evaluation. These methods are driven by `tfm.core.train_lib.run_experiment`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Ns9LAsiXZLuX"
      },
      "outputs": [],
      "source": [
        "with distribution_strategy.scope():\n",
        "  task = tfm.core.task_factory.get_task(exp_config.task, logging_dir=model_dir)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vTKbQxDkbArE"
      },
      "source": [
        "## Visualize a batch of the data."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3RIlbhj0ZNvt"
      },
      "outputs": [],
      "source": [
        "for images, labels in task.build_inputs(exp_config.task.train_data).take(1):\n",
        "  print()\n",
        "  print(f'images.shape: {str(images.shape):16}  images.dtype: {images.dtype!r}')\n",
        "  print(f'labels.keys: {labels.keys()}')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "m-QW7DoKbD8z"
      },
      "source": [
        "### Create a category index dictionary to map the labels to the corresponding label names."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MN0sSthbZR-s"
      },
      "outputs": [],
      "source": [
        "category_index={\n",
        "    1: {\n",
        "        'id': 1,\n",
        "        'name': 'Platelets'\n",
        "       },\n",
        "    2: {\n",
        "        'id': 2,\n",
        "        'name': 'RBC'\n",
        "       },\n",
        "    3: {\n",
        "        'id': 3,\n",
        "        'name': 'WBC'\n",
        "       }\n",
        "}\n",
        "tf_ex_decoder = TfExampleDecoder()"
      ]
    },
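    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A category index like the one above can also be derived from the `categories` list of a COCO annotation file instead of being typed by hand. The sketch below uses a small inline list as a stand-in for that file. The helper name `build_category_index`, the `skip_ids` parameter, and the sample entries are illustrative assumptions, so check them against your own annotation file."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def build_category_index(categories, skip_ids=(0,)):\n",
        "  # Maps each category id to a dict holding its 'id' and 'name'.\n",
        "  # skip_ids is an assumption: id 0 is treated as a background or parent class.\n",
        "  return {\n",
        "      c['id']: {'id': c['id'], 'name': c['name']}\n",
        "      for c in categories\n",
        "      if c['id'] not in skip_ids\n",
        "  }\n",
        "\n",
        "# Illustrative stand-in for the 'categories' entry of a COCO annotation file.\n",
        "sample_categories = [\n",
        "    {'id': 0, 'name': 'cells'},\n",
        "    {'id': 1, 'name': 'Platelets'},\n",
        "    {'id': 2, 'name': 'RBC'},\n",
        "    {'id': 3, 'name': 'WBC'},\n",
        "]\n",
        "sketch_index = build_category_index(sample_categories)\n",
        "print(sketch_index)"
      ]
    },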
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AcbmD1pRbGcS"
      },
      "source": [
        "### Helper function for visualizing the results from TFRecords.\n",
        "Use `visualize_boxes_and_labels_on_image_array` from `visualization_utils` to draw bounding boxes on the image."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wWBeomMMZThI"
      },
      "outputs": [],
      "source": [
        "def show_batch(raw_records, num_of_examples):\n",
        "  plt.figure(figsize=(20, 20))\n",
        "  use_normalized_coordinates=True\n",
        "  min_score_thresh = 0.30\n",
        "  for i, serialized_example in enumerate(raw_records):\n",
        "    plt.subplot(1, num_of_examples, i + 1)\n",
        "    decoded_tensors = tf_ex_decoder.decode(serialized_example)\n",
        "    image = decoded_tensors['image'].numpy().astype('uint8')\n",
        "    scores = np.ones(shape=(len(decoded_tensors['groundtruth_boxes'])))\n",
        "    visualization_utils.visualize_boxes_and_labels_on_image_array(\n",
        "        image,\n",
        "        decoded_tensors['groundtruth_boxes'].numpy(),\n",
        "        decoded_tensors['groundtruth_classes'].numpy().astype('int'),\n",
        "        scores,\n",
        "        category_index=category_index,\n",
        "        use_normalized_coordinates=use_normalized_coordinates,\n",
        "        max_boxes_to_draw=200,\n",
        "        min_score_thresh=min_score_thresh,\n",
        "        agnostic_mode=False,\n",
        "        instance_masks=None,\n",
        "        line_thickness=4)\n",
        "\n",
        "    plt.imshow(image)\n",
        "    plt.axis('off')\n",
        "    plt.title(f'Image-{i+1}')\n",
        "  plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "R3EgriELbJly"
      },
      "source": [
        "### Visualization of train data\n",
        "\n",
        "Each bounding box drawn on an image has two components:\n",
        "  1. The class label of the object (e.g. RBC).\n",
        "  2. A score shown as a percentage. For model predictions this is the detection confidence.\n",
        "\n",
        "**Note**: Every score is 100% here because the boxes are the ground truth, and `show_batch` sets all scores to 1."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hdrsciGIZVNO"
      },
      "outputs": [],
      "source": [
        "buffer_size = 20\n",
        "num_of_examples = 3\n",
        "\n",
        "raw_records = tf.data.TFRecordDataset(\n",
        "    exp_config.task.train_data.input_path).shuffle(\n",
        "        buffer_size=buffer_size).take(num_of_examples)\n",
        "show_batch(raw_records, num_of_examples)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IrWkJPyEbMKg"
      },
      "source": [
        "## Train and evaluate.\n",
        "\n",
        "We follow the COCO challenge tradition and evaluate object detection accuracy with mAP (mean Average Precision). See the [COCO detection evaluation page](https://cocodataset.org/#detection-eval) for a detailed explanation of how the detection metrics are computed.\n",
        "\n",
        "**IoU** (intersection over union) is defined as the area of the intersection divided by the area of the union of a predicted bounding box and a ground-truth bounding box."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SCjHHXvfZXX1"
      },
      "outputs": [],
      "source": [
        "model, eval_logs = tfm.core.train_lib.run_experiment(\n",
        "    distribution_strategy=distribution_strategy,\n",
        "    task=task,\n",
        "    mode='train_and_eval',\n",
        "    params=exp_config,\n",
        "    model_dir=model_dir,\n",
        "    run_post_eval=True)"
      ]
    },
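    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The IoU definition used by the evaluation metrics can be checked by hand. The cell below is a minimal sketch for a single pair of boxes. The helper `box_iou` is hypothetical, and the `[ymin, xmin, ymax, xmax]` box layout is an assumption of this sketch. The COCO evaluator computes IoU internally and does not use this function."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def box_iou(box_a, box_b):\n",
        "  # Boxes are [ymin, xmin, ymax, xmax] (layout assumed for this sketch).\n",
        "  ymin = max(box_a[0], box_b[0])\n",
        "  xmin = max(box_a[1], box_b[1])\n",
        "  ymax = min(box_a[2], box_b[2])\n",
        "  xmax = min(box_a[3], box_b[3])\n",
        "  # The overlap is empty when the boxes do not intersect.\n",
        "  intersection = max(0.0, ymax - ymin) * max(0.0, xmax - xmin)\n",
        "  area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])\n",
        "  area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])\n",
        "  union = area_a + area_b - intersection\n",
        "  return intersection / union if union > 0 else 0.0\n",
        "\n",
        "# Two 10x10 boxes that overlap in a 5x5 region.\n",
        "print(box_iou([0, 0, 10, 10], [5, 5, 15, 15]))"
      ]
    },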
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2Gd6uHLjbPKW"
      },
      "source": [
        "## Load logs in TensorBoard."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Q6iDRUVqZY86"
      },
      "outputs": [],
      "source": [
        "%load_ext tensorboard\n",
        "%tensorboard --logdir './trained_model/'"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AoL2MIJobReU"
      },
      "source": [
        "## Saving and exporting the trained model.\n",
        "\n",
        "The `keras.Model` object returned by `train_lib.run_experiment` expects the data to be normalized by the dataset loader using the same mean and variance statistics as in `preprocess_ops.normalize_image(image, offset=MEAN_RGB, scale=STDDEV_RGB)`. The export function handles those details, so you can pass `tf.uint8` images and get the correct results."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CmOBYXdXZah4"
      },
      "outputs": [],
      "source": [
        "export_saved_model_lib.export_inference_graph(\n",
        "    input_type='image_tensor',\n",
        "    batch_size=1,\n",
        "    input_image_size=[HEIGHT, WIDTH],\n",
        "    params=exp_config,\n",
        "    checkpoint_path=tf.train.latest_checkpoint(model_dir),\n",
        "    export_dir=export_dir)"
      ]
    },
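    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the normalization step concrete, the cell below sketches per-channel normalization in NumPy. It is an illustration rather than the library's code: `sketch_normalize` is a hypothetical helper, and the mean and standard deviation values are the commonly used ImageNet statistics scaled to the 0-255 range, which this sketch assumes correspond to `MEAN_RGB` and `STDDEV_RGB` in `preprocess_ops`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "# Assumed ImageNet statistics on the 0-255 scale (verify against preprocess_ops).\n",
        "SKETCH_MEAN_RGB = np.array([0.485, 0.456, 0.406]) * 255\n",
        "SKETCH_STDDEV_RGB = np.array([0.229, 0.224, 0.225]) * 255\n",
        "\n",
        "def sketch_normalize(image_uint8):\n",
        "  # Subtract the per-channel mean, then divide by the per-channel stddev.\n",
        "  image = image_uint8.astype(np.float32)\n",
        "  return (image - SKETCH_MEAN_RGB) / SKETCH_STDDEV_RGB\n",
        "\n",
        "# A uniform gray image stands in for a real uint8 input.\n",
        "sketch_image = np.full((2, 2, 3), 128, dtype=np.uint8)\n",
        "print(sketch_normalize(sketch_image)[0, 0])"
      ]
    },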
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_JhXopm8bU1g"
      },
      "source": [
        "## Inference from trained model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EbD4j1uCZcIV"
      },
      "outputs": [],
      "source": [
        "def load_image_into_numpy_array(path):\n",
        "  \"\"\"Load an image from file into a numpy array.\n",
        "\n",
        "  Puts image into numpy array to feed into tensorflow graph.\n",
        "  Note that by convention we put it into a numpy array with shape\n",
        "  (height, width, channels), where channels=3 for RGB.\n",
        "\n",
        "  Args:\n",
        "    path: the file path to the image\n",
        "\n",
        "  Returns:\n",
        "    uint8 numpy array with shape (1, img_height, img_width, 3)\n",
        "  \"\"\"\n",
        "  image = None\n",
        "  if path.startswith('http'):\n",
        "    response = urlopen(path)\n",
        "    image_data = response.read()\n",
        "    image_data = BytesIO(image_data)\n",
        "    image = Image.open(image_data)\n",
        "  else:\n",
        "    image_data = tf.io.gfile.GFile(path, 'rb').read()\n",
        "    image = Image.open(BytesIO(image_data))\n",
        "\n",
        "  (im_width, im_height) = image.size\n",
        "  return np.array(image.getdata()).reshape(\n",
        "      (1, im_height, im_width, 3)).astype(np.uint8)\n",
        "\n",
        "\n",
        "\n",
        "def build_inputs_for_object_detection(image, input_image_size):\n",
        "  \"\"\"Builds Object Detection model inputs for serving.\"\"\"\n",
        "  image, _ = resize_and_crop_image(\n",
        "      image,\n",
        "      input_image_size,\n",
        "      padded_size=input_image_size,\n",
        "      aug_scale_min=1.0,\n",
        "      aug_scale_max=1.0)\n",
        "  return image"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "o8bguhK_batq"
      },
      "source": [
        "### Visualize test data."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sOsDhYmyZd_m"
      },
      "outputs": [],
      "source": [
        "num_of_examples = 3\n",
        "\n",
        "test_ds = tf.data.TFRecordDataset(\n",
        "    './bccd_coco_tfrecords/test-00000-of-00001.tfrecord').take(\n",
        "        num_of_examples)\n",
        "show_batch(test_ds, num_of_examples)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kcYnb1Zfbba9"
      },
      "source": [
        "### Importing SavedModel."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nQ6waz9rZfhy"
      },
      "outputs": [],
      "source": [
        "imported = tf.saved_model.load(export_dir)\n",
        "model_fn = imported.signatures['serving_default']"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CtB4gfZ3bfiC"
      },
      "source": [
        "### Visualize predictions."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UTSfNZ6yZhEV"
      },
      "outputs": [],
      "source": [
        "input_image_size = (HEIGHT, WIDTH)\n",
        "plt.figure(figsize=(20, 20))\n",
        "min_score_thresh = 0.30 # Lower this threshold to also draw lower-confidence boxes.\n",
        "\n",
        "for i, serialized_example in enumerate(test_ds):\n",
        "  plt.subplot(1, num_of_examples, i+1)\n",
        "  decoded_tensors = tf_ex_decoder.decode(serialized_example)\n",
        "  image = build_inputs_for_object_detection(decoded_tensors['image'], input_image_size)\n",
        "  image = tf.expand_dims(image, axis=0)\n",
        "  image = tf.cast(image, dtype = tf.uint8)\n",
        "  image_np = image[0].numpy()\n",
        "  result = model_fn(image)\n",
        "  visualization_utils.visualize_boxes_and_labels_on_image_array(\n",
        "      image_np,\n",
        "      result['detection_boxes'][0].numpy(),\n",
        "      result['detection_classes'][0].numpy().astype(int),\n",
        "      result['detection_scores'][0].numpy(),\n",
        "      category_index=category_index,\n",
        "      use_normalized_coordinates=False,\n",
        "      max_boxes_to_draw=200,\n",
        "      min_score_thresh=min_score_thresh,\n",
        "      agnostic_mode=False,\n",
        "      instance_masks=None,\n",
        "      line_thickness=4)\n",
        "  plt.imshow(image_np)\n",
        "  plt.axis('off')\n",
        "\n",
        "plt.show()"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "name": "object_detection.ipynb",
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.8"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
